Explaining a Question Answering Transformers Model
Here we demonstrate how to explain the output of a question answering model that predicts which range of the context text contains the answer to a given question.
[1]:
import numpy as np
import torch
import transformers

import shap

# load the model
pmodel = transformers.pipeline("question-answering")

tokenized_qs = None  # variable to store the tokenized data


# define two prediction functions, one that outputs the logits for the range start,
# and the other for the range end
def f(questions, tokenized_qs, start):
    outs = []
    for q in questions:
        # find the separator between the question and the context
        # (this code assumes that there is only one sentence in data)
        idx = np.argwhere(np.array(tokenized_qs["input_ids"]) == pmodel.tokenizer.sep_token_id)[0, 0]
        # replace the question tokens (before the separator) and the context tokens
        # (after it) with the possibly-masked token ids supplied by the explainer
        d = tokenized_qs.copy()
        d["input_ids"][:idx] = q[:idx]
        d["input_ids"][idx + 1 :] = q[idx + 1 :]
        out = pmodel.model.forward(**{k: torch.tensor(d[k]).reshape(1, -1) for k in d})
        logits = out.start_logits if start else out.end_logits
        outs.append(logits.reshape(-1).detach().numpy())
    return outs


def tokenize_data(data):
    for q in data:
        question, context = q.split("[SEP]")
        tokenized_data = pmodel.tokenizer(question, context)
    return tokenized_data  # this code assumes that there is only one sentence in data


def f_start(questions):
    return f(questions, tokenized_qs, True)


def f_end(questions):
    return f(questions, tokenized_qs, False)


# attach a dynamic output_names property to the models so we can plot the tokens at each output position
def out_names(inputs):
    question, context = inputs.split("[SEP]")
    d = pmodel.tokenizer(question, context)
    return [pmodel.tokenizer.decode([id]) for id in d["input_ids"]]


f_start.output_names = out_names
f_end.output_names = out_names
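Before computing any explanations, it can help to see what the raw pipeline produces. Below is a minimal sketch; the exact score and character offsets depend on the default model the pipeline downloads, and the printed answer is illustrative only:

# the question-answering pipeline returns a dict with the predicted answer span,
# its character offsets in the context, and a confidence score
context = "When I got home today I saw my cat on the table, and my frog on the floor."
print(pmodel(question="What is on the table?", context=context))
# e.g. {'score': ..., 'start': ..., 'end': ..., 'answer': 'my cat'}  (illustrative only)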
Explain the starting positions
Here we explain the starting range predictions of the model. Note that because the model output depends on the length of the model input, it is important that we pass the model’s native tokenizer for masking, so that when we hide portions of the text we retain the same number of tokens and hence the same meaning for each output position.
[2]:
data = [
    "What is on the table?[SEP]When I got home today I saw my cat on the table, and my frog on the floor.",
]  # this code assumes that there is only one sentence in data
tokenized_qs = tokenize_data(data)
explainer_start = shap.Explainer(f_start, shap.maskers.Text(tokenizer=pmodel.tokenizer, output_type="ids"))
shap_values_start = explainer_start(data)
shap.plots.text(shap_values_start)
Partition explainer: 2it [00:32, 32.86s/it]
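Because each output position corresponds to one token of the tokenized question and context, a quick sanity check (a sketch using the out_names helper defined above) is that the number of output names matches the number of input ids:

# sanity check (sketch): one output position per token in the tokenized question + context
assert len(out_names(data[0])) == len(tokenized_qs["input_ids"])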
Explain the ending positions
This is the same process as above, but now we explain the ending position predictions.
[3]:
explainer_end = shap.Explainer(f_end, shap.maskers.Text(tokenizer=pmodel.tokenizer, output_type="ids"))
shap_values_end = explainer_end(data)
shap.plots.text(shap_values_end)
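As a rough sketch of how the two sets of logits fit together: taking the argmax of the start and end logits approximates the span the model predicts (the pipeline itself performs a more careful constrained search over valid spans). The ids, start_idx, and end_idx names below are only illustrative:

# rough sketch: the largest start/end logits approximate the predicted answer span
ids = np.array(tokenize_data(data)["input_ids"])
start_idx = int(np.argmax(f_start([ids])[0]))
end_idx = int(np.argmax(f_end([ids])[0]))
print(pmodel.tokenizer.decode(ids[start_idx : end_idx + 1]))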
Explain a matching function
In the example above we directly explained the output logits coming from the model. This required us to perturb the input only in length-preserving ways, so as not to change the meaning of the output logits. A less detailed but more flexible approach is to simply score whether specific answers are produced by the model.
[4]:
def make_answer_scorer(answers):
    def f(questions):
        out = []
        for q in questions:
            question, context = q.split("[SEP]")
            results = pmodel(question, context, topk=20)
            # score each candidate answer by the probability the pipeline assigns to it
            # (0 if the answer does not appear among the top 20 predictions)
            values = []
            for answer in answers:
                value = 0
                for result in results:
                    if result["answer"] == answer:
                        value = result["score"]
                        break
                values.append(value)
            out.append(values)
        return out

    f.output_names = answers
    return f

f_answers = make_answer_scorer(["my cat", "cat", "my frog"])
explainer_answers = shap.Explainer(f_answers, pmodel.tokenizer)
shap_values_answers = explainer_answers(data)
shap.plots.text(shap_values_answers)
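The same pattern works for any question and set of candidate answers, since the scorer re-splits the question and context from each input string. For example, a sketch reusing the same context with a different question (the data_floor, f_floor, and explainer_floor names are only illustrative):

# a sketch: score candidate answers for a different question over the same context
data_floor = [
    "What is on the floor?[SEP]When I got home today I saw my cat on the table, and my frog on the floor.",
]
f_floor = make_answer_scorer(["my frog", "frog"])
explainer_floor = shap.Explainer(f_floor, pmodel.tokenizer)
shap.plots.text(explainer_floor(data_floor))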
Have an idea for more helpful examples? Pull requests that add to this documentation notebook are encouraged!